#!/usr/bin/env python3
# -*-coding=utf-8-*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def start_train():
    """Run one node of a distributed MNIST softmax-regression training job.

    The node's role is taken from command-line flags:

    * ``--ps_hosts``     comma-separated ``host:port`` list of parameter servers
    * ``--worker_hosts`` comma-separated ``host:port`` list of workers
    * ``--job_name``     ``'ps'`` or ``'worker'``
    * ``--task_index``   index of this node within its job

    A ``ps`` node starts its server and blocks in ``server.join()``.
    A ``worker`` node builds the model, trains until the shared global step
    reaches 1000 (or the supervisor asks to stop), and then prints the
    accuracy measured on the MNIST test set.  Any other ``job_name`` that
    gets past server creation results in no further work.
    """
    # Parse the command-line flags.
    tf.app.flags.DEFINE_string('ps_hosts', "", "")
    tf.app.flags.DEFINE_string('worker_hosts', "", "")
    tf.app.flags.DEFINE_string('job_name', "", "")
    tf.app.flags.DEFINE_integer('task_index', 0, "")
    FLAGS = tf.app.flags.FLAGS
    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')
    job_name = FLAGS.job_name
    task_index = FLAGS.task_index
    # Build the cluster description and start this node's server.
    cluster = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})
    server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)
    if job_name == 'ps':
        # Parameter servers only host variables; block here serving requests.
        server.join()
    elif job_name == 'worker':
        # Download/load the training and test images.  Done only on workers:
        # a parameter server never touches the data.
        mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
        with tf.device(tf.train.replica_device_setter(
                worker_device='/job:worker/task:%d' % task_index,
                cluster=cluster)):
            # ---- Model definition ----
            # Flattened image pixels.
            x = tf.placeholder(tf.float32, [None, 784])
            # Weights.
            W = tf.Variable(tf.zeros([784, 10]))
            # Biases.
            b = tf.Variable(tf.zeros([10]))
            y = tf.nn.softmax(tf.matmul(x, W) + b)
            # One-hot ground-truth labels.
            y_ = tf.placeholder(tf.float32, [None, 10])
            cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
            # The step counter is bookkeeping, not a model parameter, so keep
            # it out of the trainable-variables collection.
            global_step = tf.Variable(0, trainable=False)
            train_op = tf.train.AdagradOptimizer(0.01).minimize(
                cross_entropy, global_step=global_step)
            saver = tf.train.Saver()
            summary_op = tf.summary.merge_all()
            init_op = tf.global_variables_initializer()
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
            # ---- End of model definition ----
            # The supervisor manages this worker; task 0 acts as chief.
            sv = tf.train.Supervisor(is_chief=(task_index == 0),
                                     logdir='/tmp/tensorflow/train_logs',
                                     init_op=init_op,
                                     summary_op=summary_op,
                                     saver=saver,
                                     global_step=global_step,
                                     save_model_secs=600)
            with sv.managed_session(server.target) as sess:
                step = 0
                # `step` is the shared global step, so the 1000-step budget is
                # counted across all workers, not per worker.
                while not sv.should_stop() and step < 1000:
                    batch_xs, batch_ys = mnist.train.next_batch(100)
                    _, step = sess.run(
                        [train_op, global_step],
                        feed_dict={x: batch_xs, y_: batch_ys})
                # Evaluate accuracy on the test set.  This must happen inside
                # the managed_session block: leaving the block closes the
                # session and stops the supervisor, so no explicit sv.stop()
                # is needed afterwards.
                print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                    y_: mnist.test.labels}))


# Script entry point: start this node only when the file is executed directly.
if __name__ == "__main__":
    start_train()
